{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cross-validation for 4 different objective functions\n",
    "\n",
    "1. CaDRReS\n",
    "2. CaDRReS + no bp\n",
    "3. CaDRReS + no bp + ciu (ciu = drug-sample weight based on max_conc)\n",
    "4. CaDRReS + no bp + ciu + du (du = sample weight based on cancer type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read gene expression file and calculate kernel features\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:44.565077Z",
     "start_time": "2020-06-23T06:58:39.497657Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import sys, os, pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=2)\n",
    "from collections import Counter\n",
    "import importlib\n",
    "\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['figure.dpi']= 300\n",
    "mpl.rc(\"savefig\", dpi=300)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "\n",
    "scriptpath = '..'\n",
    "sys.path.append(os.path.abspath(scriptpath))\n",
    "\n",
    "from cadrres import pp, model, evaluation, utility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:44.571002Z",
     "start_time": "2020-06-23T06:58:44.566648Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.14.0'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read cell line info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:44.732855Z",
     "start_time": "2020-06-23T06:58:44.572713Z"
    }
   },
   "outputs": [],
   "source": [
    "gdsc_sample_df = pd.read_csv('../data/GDSC/GDSC_tissue_info.csv', index_col=0)\n",
    "gdsc_sample_df.index = gdsc_sample_df.index.astype(str)\n",
    "\n",
    "gdsc_obs_df = pd.read_csv('../data/GDSC/gdsc_all_abs_ic50_bayesian_sigmoid_only9dosages.csv', index_col=0)\n",
    "gdsc_obs_df.index = gdsc_obs_df.index.astype(str)\n",
    "\n",
    "gdsc_sample_list = gdsc_obs_df.index.astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:44.739738Z",
     "start_time": "2020-06-23T06:58:44.734440Z"
    }
   },
   "outputs": [],
   "source": [
    "indication_count_df = gdsc_sample_df.groupby(['TCGA_CLASS']).size().sort_values(ascending=False).drop('UNCLASSIFIED')\n",
    "selected_indications = indication_count_df.index[indication_count_df >= 35]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read drug info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:44.769569Z",
     "start_time": "2020-06-23T06:58:44.740974Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(226, 27)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# gdsc_drug_df = pd.read_csv('../preprocessed_data/GDSC/hn_drug_stat.csv', index_col=0)\n",
    "gdsc_drug_df = pd.read_csv('../preprocessed_data/GDSC/drug_stat.csv', index_col=0)\n",
    "gdsc_drug_df.index = gdsc_drug_df.index.astype(str)\n",
    "\n",
    "gdsc_drug_list = gdsc_drug_df.index\n",
    "gdsc_drug_df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read and normalize gene expression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:52.616002Z",
     "start_time": "2020-06-23T06:58:44.770752Z"
    }
   },
   "outputs": [],
   "source": [
    "gdsc_log2_exp_df = pd.read_csv('../data/GDSC/GDSC_exp.tsv', sep='\\t', index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:52.623588Z",
     "start_time": "2020-06-23T06:58:52.617345Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17737, 1018)\n",
      "17420\n",
      "1018\n"
     ]
    }
   ],
   "source": [
    "print (gdsc_log2_exp_df.shape) \n",
    "print (len(gdsc_log2_exp_df.index.unique())) # included Nan\n",
    "print (len(gdsc_log2_exp_df.columns.unique()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For genes measured by multiple probes, take the mean across probes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:52.984470Z",
     "start_time": "2020-06-23T06:58:52.624881Z"
    }
   },
   "outputs": [],
   "source": [
    "gdsc_log2_exp_df = gdsc_log2_exp_df.groupby(gdsc_log2_exp_df.index).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate log2 fold-change relative to the mean expression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.125324Z",
     "start_time": "2020-06-23T06:58:52.987816Z"
    }
   },
   "outputs": [],
   "source": [
    "gdsc_log2_mean_fc_exp_df, gdsc_mean_exp_df = pp.gexp.normalize_log2_mean_fc(gdsc_log2_exp_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read essential genes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.139988Z",
     "start_time": "2020-06-23T06:58:58.127005Z"
    }
   },
   "outputs": [],
   "source": [
    "ess_gene_list = utility.get_gene_list('../data/essential_genes.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Samples with both expression and response data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.145723Z",
     "start_time": "2020-06-23T06:58:58.141255Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "985"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gdsc_sample_list = np.array([s for s in gdsc_sample_list if s in gdsc_log2_exp_df.columns])\n",
    "len(gdsc_sample_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.152707Z",
     "start_time": "2020-06-23T06:58:58.147009Z"
    }
   },
   "outputs": [],
   "source": [
    "gdsc_sample_df = gdsc_sample_df.loc[gdsc_sample_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.168748Z",
     "start_time": "2020-06-23T06:58:58.153960Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SCLC (60,)\n",
      "LUAD (63,)\n",
      "SKCM (52,)\n",
      "BRCA (49,)\n",
      "COREAD (48,)\n",
      "HNSC (42,)\n",
      "GBM (35,)\n",
      "ESCA (35,)\n"
     ]
    }
   ],
   "source": [
    "gdsc_sample_dict = {}\n",
    "gdsc_obs_df_dict = {}\n",
    "for i in selected_indications:\n",
    "    gdsc_sample_dict[i] = gdsc_sample_df[gdsc_sample_df['TCGA_CLASS']==i].index\n",
    "    gdsc_obs_df_dict[i] = gdsc_obs_df.loc[gdsc_sample_dict[i]]\n",
    "    print (i, gdsc_sample_dict[i].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.237780Z",
     "start_time": "2020-06-23T06:58:58.170091Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((17419, 985), (985, 226))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gdsc_log2_mean_fc_exp_df = gdsc_log2_mean_fc_exp_df[gdsc_sample_list]\n",
    "gdsc_obs_df = gdsc_obs_df.loc[gdsc_sample_list, gdsc_drug_list]\n",
    "gdsc_drug_df = gdsc_drug_df.loc[gdsc_drug_list]\n",
    "\n",
    "gdsc_log2_mean_fc_exp_df.shape, gdsc_obs_df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate kernel features\n",
    "\n",
    "Based on all 985 GDSC samples with gene expression profiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:58.241649Z",
     "start_time": "2020-06-23T06:58:58.239213Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# kernel_feature_df = pp.gexp.calculate_kernel_feature(gdsc_log2_mean_fc_exp_df, gdsc_log2_mean_fc_exp_df, ess_gene_list).loc[gdsc_sample_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:59.050224Z",
     "start_time": "2020-06-23T06:58:58.243090Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1240121</th>\n",
       "      <th>1240122</th>\n",
       "      <th>1240123</th>\n",
       "      <th>1240124</th>\n",
       "      <th>1240125</th>\n",
       "      <th>1240127</th>\n",
       "      <th>1240128</th>\n",
       "      <th>1240129</th>\n",
       "      <th>1240130</th>\n",
       "      <th>1240131</th>\n",
       "      <th>...</th>\n",
       "      <th>949175</th>\n",
       "      <th>949176</th>\n",
       "      <th>949177</th>\n",
       "      <th>949178</th>\n",
       "      <th>949179</th>\n",
       "      <th>971773</th>\n",
       "      <th>971774</th>\n",
       "      <th>971777</th>\n",
       "      <th>998184</th>\n",
       "      <th>998189</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1240121</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.200762</td>\n",
       "      <td>-0.097257</td>\n",
       "      <td>0.079455</td>\n",
       "      <td>-0.080807</td>\n",
       "      <td>-0.107964</td>\n",
       "      <td>-0.058302</td>\n",
       "      <td>0.079915</td>\n",
       "      <td>0.063199</td>\n",
       "      <td>0.035671</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.129215</td>\n",
       "      <td>-0.179337</td>\n",
       "      <td>-0.095300</td>\n",
       "      <td>-0.112817</td>\n",
       "      <td>-0.186527</td>\n",
       "      <td>-0.088457</td>\n",
       "      <td>-0.143004</td>\n",
       "      <td>-0.189747</td>\n",
       "      <td>-0.259590</td>\n",
       "      <td>-0.054617</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1240122</th>\n",
       "      <td>0.200762</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.193214</td>\n",
       "      <td>-0.049567</td>\n",
       "      <td>-0.180749</td>\n",
       "      <td>0.187601</td>\n",
       "      <td>0.042315</td>\n",
       "      <td>0.171160</td>\n",
       "      <td>-0.049354</td>\n",
       "      <td>-0.061332</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.008915</td>\n",
       "      <td>0.042224</td>\n",
       "      <td>0.080204</td>\n",
       "      <td>0.052032</td>\n",
       "      <td>-0.091817</td>\n",
       "      <td>0.007112</td>\n",
       "      <td>0.046598</td>\n",
       "      <td>0.099549</td>\n",
       "      <td>-0.010853</td>\n",
       "      <td>-0.037156</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1240123</th>\n",
       "      <td>-0.097257</td>\n",
       "      <td>0.193214</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.165331</td>\n",
       "      <td>-0.079250</td>\n",
       "      <td>0.187299</td>\n",
       "      <td>-0.092017</td>\n",
       "      <td>-0.022646</td>\n",
       "      <td>0.057621</td>\n",
       "      <td>-0.160944</td>\n",
       "      <td>...</td>\n",
       "      <td>0.082774</td>\n",
       "      <td>0.044117</td>\n",
       "      <td>0.080209</td>\n",
       "      <td>0.092592</td>\n",
       "      <td>0.025462</td>\n",
       "      <td>-0.139113</td>\n",
       "      <td>0.126919</td>\n",
       "      <td>0.068192</td>\n",
       "      <td>0.098840</td>\n",
       "      <td>0.289321</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1240124</th>\n",
       "      <td>0.079455</td>\n",
       "      <td>-0.049567</td>\n",
       "      <td>0.165331</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.213386</td>\n",
       "      <td>-0.134121</td>\n",
       "      <td>-0.063749</td>\n",
       "      <td>-0.069065</td>\n",
       "      <td>0.121054</td>\n",
       "      <td>0.002488</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.271613</td>\n",
       "      <td>-0.177998</td>\n",
       "      <td>-0.222713</td>\n",
       "      <td>-0.215761</td>\n",
       "      <td>-0.092428</td>\n",
       "      <td>-0.068879</td>\n",
       "      <td>-0.154265</td>\n",
       "      <td>-0.092515</td>\n",
       "      <td>-0.003545</td>\n",
       "      <td>0.374791</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1240125</th>\n",
       "      <td>-0.080807</td>\n",
       "      <td>-0.180749</td>\n",
       "      <td>-0.079250</td>\n",
       "      <td>0.213386</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.048241</td>\n",
       "      <td>-0.110779</td>\n",
       "      <td>0.039282</td>\n",
       "      <td>0.270578</td>\n",
       "      <td>0.084228</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.176282</td>\n",
       "      <td>-0.160407</td>\n",
       "      <td>-0.128728</td>\n",
       "      <td>-0.236164</td>\n",
       "      <td>0.028765</td>\n",
       "      <td>0.033600</td>\n",
       "      <td>-0.054744</td>\n",
       "      <td>0.132978</td>\n",
       "      <td>-0.040489</td>\n",
       "      <td>0.017542</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 985 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          1240121   1240122   1240123   1240124   1240125   1240127   1240128  \\\n",
       "1240121  1.000000  0.200762 -0.097257  0.079455 -0.080807 -0.107964 -0.058302   \n",
       "1240122  0.200762  1.000000  0.193214 -0.049567 -0.180749  0.187601  0.042315   \n",
       "1240123 -0.097257  0.193214  1.000000  0.165331 -0.079250  0.187299 -0.092017   \n",
       "1240124  0.079455 -0.049567  0.165331  1.000000  0.213386 -0.134121 -0.063749   \n",
       "1240125 -0.080807 -0.180749 -0.079250  0.213386  1.000000 -0.048241 -0.110779   \n",
       "\n",
       "          1240129   1240130   1240131  ...    949175    949176    949177  \\\n",
       "1240121  0.079915  0.063199  0.035671  ... -0.129215 -0.179337 -0.095300   \n",
       "1240122  0.171160 -0.049354 -0.061332  ... -0.008915  0.042224  0.080204   \n",
       "1240123 -0.022646  0.057621 -0.160944  ...  0.082774  0.044117  0.080209   \n",
       "1240124 -0.069065  0.121054  0.002488  ... -0.271613 -0.177998 -0.222713   \n",
       "1240125  0.039282  0.270578  0.084228  ... -0.176282 -0.160407 -0.128728   \n",
       "\n",
       "           949178    949179    971773    971774    971777    998184    998189  \n",
       "1240121 -0.112817 -0.186527 -0.088457 -0.143004 -0.189747 -0.259590 -0.054617  \n",
       "1240122  0.052032 -0.091817  0.007112  0.046598  0.099549 -0.010853 -0.037156  \n",
       "1240123  0.092592  0.025462 -0.139113  0.126919  0.068192  0.098840  0.289321  \n",
       "1240124 -0.215761 -0.092428 -0.068879 -0.154265 -0.092515 -0.003545  0.374791  \n",
       "1240125 -0.236164  0.028765  0.033600 -0.054744  0.132978 -0.040489  0.017542  \n",
       "\n",
       "[5 rows x 985 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# kernel_feature_df.to_csv('../preprocessed_data/GDSC/kernel_features.csv')\n",
    "# kernel_feature_df.head()\n",
    "\n",
    "kernel_feature_df = pd.read_csv('../preprocessed_data/GDSC/kernel_features.csv', index_col=0)\n",
    "kernel_feature_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross validation (5-fold)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create 5-fold datasets\n",
    "\n",
    "For each selected indication, split its samples evenly across the 5 folds. Samples from all other indications are pooled and split randomly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:59.053874Z",
     "start_time": "2020-06-23T06:58:59.051558Z"
    }
   },
   "outputs": [],
   "source": [
    "# from sklearn.model_selection import KFold\n",
    "# from collections import defaultdict\n",
    "# kf = KFold(n_splits=5, random_state=0, shuffle=True)\n",
    "\n",
    "# k_train_sample_dict = defaultdict(list)\n",
    "# k_val_sample_dict = defaultdict(list)\n",
    "\n",
    "# for i in selected_indications:\n",
    "#     k = 1\n",
    "#     for train_index, val_index in kf.split(gdsc_sample_dict[i]):\n",
    "#         k_train_sample_dict[k] += list(gdsc_sample_dict[i][train_index].astype(str))\n",
    "#         k_val_sample_dict[k] += list(gdsc_sample_dict[i][val_index].astype(str))\n",
    "#         k += 1\n",
    "\n",
    "# other_sample_list = gdsc_sample_df[~gdsc_sample_df['TCGA_CLASS'].isin(selected_indications)].index.values\n",
    "\n",
    "# k = 1\n",
    "# for train_index, val_index in kf.split(other_sample_list):\n",
    "#     k_train_sample_dict[k] += list(other_sample_list[train_index].astype(str))\n",
    "#     k_val_sample_dict[k] += list(other_sample_list[val_index].astype(str))\n",
    "#     k += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:59.058983Z",
     "start_time": "2020-06-23T06:58:59.055175Z"
    }
   },
   "outputs": [],
   "source": [
    "# for k in range(1, 5+1):\n",
    "#     train_samples = k_train_sample_dict[k]\n",
    "#     val_samples = k_val_sample_dict[k]\n",
    "    \n",
    "#     cv_dict = {}\n",
    "    \n",
    "#     # kernel feature based only on training samples\n",
    "#     cv_dict['X_train'] = kernel_feature_df.loc[train_samples, train_samples]\n",
    "#     cv_dict['X_test'] = kernel_feature_df.loc[val_samples, train_samples]\n",
    "    \n",
    "#     # log2 fold-change features\n",
    "#     cv_dict['X_fc_train'] = gdsc_log2_mean_fc_exp_df.T.loc[train_samples]\n",
    "#     cv_dict['X_fc_test'] = gdsc_log2_mean_fc_exp_df.T.loc[val_samples]\n",
    "    \n",
    "#     # observed drug response\n",
    "#     cv_dict['Y_train'] = gdsc_obs_df.loc[train_samples]\n",
    "#     cv_dict['Y_test'] = gdsc_obs_df.loc[val_samples]\n",
    "    \n",
    "#     for name in ['X_train', 'X_test', 'Y_train', 'Y_test', 'X_fc_train', 'X_fc_test']:\n",
    "#         cv_dict[name].to_csv('../preprocessed_data/GDSC/cv_data/{}_5f_{}.csv'.format(name, k))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and predict on the validation set\n",
    "\n",
    "- train_model for 'cadrres', 'cadrres-wo-sample-bias'\n",
    "- train_model_logistic_weight (with d_u and c_iu; no sample bias implementation)\n",
    "    - cadrres-wo-sample-bias-weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T06:58:59.066297Z",
     "start_time": "2020-06-23T06:58:59.060270Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'cadrres.utility' from '/mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/utility.py'>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "importlib.reload(pp)\n",
    "importlib.reload(model)\n",
    "importlib.reload(evaluation)\n",
    "importlib.reload(utility)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train non-indication-specific models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T10:23:39.936936Z",
     "start_time": "2020-06-23T07:00:56.257437Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fold # 1\n",
      "Initializing the model ...\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:101: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:122: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.\n",
      "\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:124: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "Train: 143852 out of 177410\n",
      "Starting model training ...\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:320: The name tf.train.GradientDescentOptimizer is deprecated. Please use tf.compat.v1.train.GradientDescentOptimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:322: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.\n",
      "\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:324: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n",
      "WARNING:tensorflow:From /mnt/volume1/Dropbox/Research/2019_drug_response_heterogeneity/CaDRReS_depository/cadrres/model.py:330: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n",
      "MSE train at step 0: 18.348 (0.01m)\n",
      "MSE train at step 5000: 11.737 (0.91m)\n",
      "MSE train at step 10000: 8.471 (1.82m)\n",
      "MSE train at step 15000: 6.514 (2.73m)\n",
      "MSE train at step 20000: 5.351 (3.64m)\n",
      "MSE train at step 25000: 4.638 (4.55m)\n",
      "MSE train at step 30000: 4.186 (5.45m)\n",
      "MSE train at step 35000: 3.888 (6.36m)\n",
      "MSE train at step 40000: 3.686 (7.26m)\n",
      "MSE train at step 45000: 3.543 (8.17m)\n",
      "MSE train at step 50000: 3.438 (9.08m)\n",
      "MSE train at step 55000: 3.358 (9.98m)\n",
      "MSE train at step 60000: 3.294 (10.89m)\n",
      "MSE train at step 65000: 3.242 (11.80m)\n",
      "MSE train at step 70000: 3.198 (12.70m)\n",
      "MSE train at step 75000: 3.161 (13.61m)\n",
      "MSE train at step 80000: 3.128 (14.52m)\n",
      "MSE train at step 85000: 3.100 (15.42m)\n",
      "MSE train at step 90000: 3.074 (16.33m)\n",
      "MSE train at step 95000: 3.051 (17.23m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 2\n",
      "Initializing the model ...\n",
      "Train: 143196 out of 177636\n",
      "Starting model training ...\n",
      "MSE train at step 0: 18.163 (0.00m)\n",
      "MSE train at step 5000: 11.634 (0.91m)\n",
      "MSE train at step 10000: 8.372 (1.81m)\n",
      "MSE train at step 15000: 6.434 (2.71m)\n",
      "MSE train at step 20000: 5.281 (3.62m)\n",
      "MSE train at step 25000: 4.574 (4.52m)\n",
      "MSE train at step 30000: 4.126 (5.43m)\n",
      "MSE train at step 35000: 3.831 (6.33m)\n",
      "MSE train at step 40000: 3.630 (7.24m)\n",
      "MSE train at step 45000: 3.487 (8.15m)\n",
      "MSE train at step 50000: 3.381 (9.05m)\n",
      "MSE train at step 55000: 3.301 (9.95m)\n",
      "MSE train at step 60000: 3.238 (10.86m)\n",
      "MSE train at step 65000: 3.187 (11.76m)\n",
      "MSE train at step 70000: 3.144 (12.67m)\n",
      "MSE train at step 75000: 3.108 (13.57m)\n",
      "MSE train at step 80000: 3.077 (14.48m)\n",
      "MSE train at step 85000: 3.050 (15.38m)\n",
      "MSE train at step 90000: 3.027 (16.29m)\n",
      "MSE train at step 95000: 3.005 (17.19m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 3\n",
      "Initializing the model ...\n",
      "Train: 144641 out of 178088\n",
      "Starting model training ...\n",
      "MSE train at step 0: 18.336 (0.00m)\n",
      "MSE train at step 5000: 11.805 (0.93m)\n",
      "MSE train at step 10000: 8.499 (1.84m)\n",
      "MSE train at step 15000: 6.537 (2.76m)\n",
      "MSE train at step 20000: 5.369 (3.68m)\n",
      "MSE train at step 25000: 4.651 (4.60m)\n",
      "MSE train at step 30000: 4.196 (5.51m)\n",
      "MSE train at step 35000: 3.898 (6.43m)\n",
      "MSE train at step 40000: 3.697 (7.35m)\n",
      "MSE train at step 45000: 3.555 (8.26m)\n",
      "MSE train at step 50000: 3.451 (9.18m)\n",
      "MSE train at step 55000: 3.371 (10.10m)\n",
      "MSE train at step 60000: 3.308 (11.02m)\n",
      "MSE train at step 65000: 3.257 (11.93m)\n",
      "MSE train at step 70000: 3.214 (12.85m)\n",
      "MSE train at step 75000: 3.176 (13.77m)\n",
      "MSE train at step 80000: 3.144 (14.68m)\n",
      "MSE train at step 85000: 3.115 (15.60m)\n",
      "MSE train at step 90000: 3.090 (16.52m)\n",
      "MSE train at step 95000: 3.067 (17.44m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 4\n",
      "Initializing the model ...\n",
      "Train: 144563 out of 178540\n",
      "Starting model training ...\n",
      "MSE train at step 0: 18.298 (0.00m)\n",
      "MSE train at step 5000: 11.700 (0.89m)\n",
      "MSE train at step 10000: 8.430 (1.78m)\n",
      "MSE train at step 15000: 6.488 (2.67m)\n",
      "MSE train at step 20000: 5.330 (3.56m)\n",
      "MSE train at step 25000: 4.616 (4.45m)\n",
      "MSE train at step 30000: 4.163 (5.34m)\n",
      "MSE train at step 35000: 3.865 (6.23m)\n",
      "MSE train at step 40000: 3.663 (7.12m)\n",
      "MSE train at step 45000: 3.519 (8.01m)\n",
      "MSE train at step 50000: 3.414 (8.89m)\n",
      "MSE train at step 55000: 3.333 (9.79m)\n",
      "MSE train at step 60000: 3.268 (10.67m)\n",
      "MSE train at step 65000: 3.216 (11.56m)\n",
      "MSE train at step 70000: 3.172 (12.45m)\n",
      "MSE train at step 75000: 3.135 (13.34m)\n",
      "MSE train at step 80000: 3.103 (14.23m)\n",
      "MSE train at step 85000: 3.074 (15.12m)\n",
      "MSE train at step 90000: 3.050 (16.01m)\n",
      "MSE train at step 95000: 3.028 (16.90m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 5\n",
      "Initializing the model ...\n",
      "Train: 144704 out of 178766\n",
      "Starting model training ...\n",
      "MSE train at step 0: 18.221 (0.00m)\n",
      "MSE train at step 5000: 11.704 (0.89m)\n",
      "MSE train at step 10000: 8.438 (1.78m)\n",
      "MSE train at step 15000: 6.497 (2.67m)\n",
      "MSE train at step 20000: 5.347 (3.56m)\n",
      "MSE train at step 25000: 4.644 (4.45m)\n",
      "MSE train at step 30000: 4.197 (5.34m)\n",
      "MSE train at step 35000: 3.904 (6.23m)\n",
      "MSE train at step 40000: 3.703 (7.12m)\n",
      "MSE train at step 45000: 3.561 (8.01m)\n",
      "MSE train at step 50000: 3.457 (8.90m)\n",
      "MSE train at step 55000: 3.377 (9.79m)\n",
      "MSE train at step 60000: 3.314 (10.68m)\n",
      "MSE train at step 65000: 3.262 (11.57m)\n",
      "MSE train at step 70000: 3.219 (12.46m)\n",
      "MSE train at step 75000: 3.182 (13.35m)\n",
      "MSE train at step 80000: 3.149 (14.23m)\n",
      "MSE train at step 85000: 3.121 (15.12m)\n",
      "MSE train at step 90000: 3.095 (16.01m)\n",
      "MSE train at step 95000: 3.072 (16.90m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 1\n",
      "Getting data ...\n",
      "Initializing the model ...\n",
      "Train: 143852 out of 177410\n",
      "Starting model training ...\n",
      "WARNING:tensorflow:From /home/ubuntu/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "TF session started ...\n",
      "Starting 1st iteration ...\n",
      "MSE train at step 0: 35.183 (0.01m)\n",
      "MSE train at step 5000: 12.090 (1.14m)\n",
      "MSE train at step 10000: 5.495 (2.26m)\n",
      "MSE train at step 15000: 4.197 (3.39m)\n",
      "MSE train at step 20000: 3.734 (4.52m)\n",
      "MSE train at step 25000: 3.481 (5.65m)\n",
      "MSE train at step 30000: 3.335 (6.77m)\n",
      "MSE train at step 35000: 3.230 (7.90m)\n",
      "MSE train at step 40000: 3.150 (9.03m)\n",
      "MSE train at step 45000: 3.084 (10.16m)\n",
      "MSE train at step 50000: 3.029 (11.29m)\n",
      "MSE train at step 55000: 2.983 (12.41m)\n",
      "MSE train at step 60000: 2.942 (13.54m)\n",
      "MSE train at step 65000: 2.904 (14.67m)\n",
      "MSE train at step 70000: 2.869 (15.80m)\n",
      "MSE train at step 75000: 2.836 (16.92m)\n",
      "MSE train at step 80000: 2.802 (18.05m)\n",
      "MSE train at step 85000: 2.773 (19.18m)\n",
      "MSE train at step 90000: 2.742 (20.31m)\n",
      "MSE train at step 95000: 2.711 (21.44m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 2\n",
      "Getting data ...\n",
      "Initializing the model ...\n",
      "Train: 143196 out of 177636\n",
      "Starting model training ...\n",
      "TF session started ...\n",
      "Starting 1st iteration ...\n",
      "MSE train at step 0: 34.886 (0.01m)\n",
      "MSE train at step 5000: 11.873 (1.12m)\n",
      "MSE train at step 10000: 5.468 (2.24m)\n",
      "MSE train at step 15000: 4.185 (3.35m)\n",
      "MSE train at step 20000: 3.720 (4.46m)\n",
      "MSE train at step 25000: 3.458 (5.58m)\n",
      "MSE train at step 30000: 3.293 (6.69m)\n",
      "MSE train at step 35000: 3.171 (7.80m)\n",
      "MSE train at step 40000: 3.088 (8.92m)\n",
      "MSE train at step 45000: 3.025 (10.03m)\n",
      "MSE train at step 50000: 2.973 (11.15m)\n",
      "MSE train at step 55000: 2.926 (12.26m)\n",
      "MSE train at step 60000: 2.882 (13.38m)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE train at step 65000: 2.843 (14.49m)\n",
      "MSE train at step 70000: 2.808 (15.60m)\n",
      "MSE train at step 75000: 2.772 (16.72m)\n",
      "MSE train at step 80000: 2.740 (17.83m)\n",
      "MSE train at step 85000: 2.710 (18.94m)\n",
      "MSE train at step 90000: 2.684 (20.05m)\n",
      "MSE train at step 95000: 2.659 (21.17m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 3\n",
      "Getting data ...\n",
      "Initializing the model ...\n",
      "Train: 144641 out of 178088\n",
      "Starting model training ...\n",
      "TF session started ...\n",
      "Starting 1st iteration ...\n",
      "MSE train at step 0: 35.184 (0.01m)\n",
      "MSE train at step 5000: 12.229 (1.13m)\n",
      "MSE train at step 10000: 5.443 (2.24m)\n",
      "MSE train at step 15000: 4.161 (3.36m)\n",
      "MSE train at step 20000: 3.684 (4.48m)\n",
      "MSE train at step 25000: 3.442 (5.59m)\n",
      "MSE train at step 30000: 3.284 (6.71m)\n",
      "MSE train at step 35000: 3.172 (7.83m)\n",
      "MSE train at step 40000: 3.091 (8.94m)\n",
      "MSE train at step 45000: 3.027 (10.06m)\n",
      "MSE train at step 50000: 2.973 (11.18m)\n",
      "MSE train at step 55000: 2.926 (12.30m)\n",
      "MSE train at step 60000: 2.885 (13.41m)\n",
      "MSE train at step 65000: 2.850 (14.53m)\n",
      "MSE train at step 70000: 2.817 (15.64m)\n",
      "MSE train at step 75000: 2.787 (16.76m)\n",
      "MSE train at step 80000: 2.760 (17.88m)\n",
      "MSE train at step 85000: 2.732 (19.00m)\n",
      "MSE train at step 90000: 2.708 (20.11m)\n",
      "MSE train at step 95000: 2.684 (21.23m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 4\n",
      "Getting data ...\n",
      "Initializing the model ...\n",
      "Train: 144563 out of 178540\n",
      "Starting model training ...\n",
      "TF session started ...\n",
      "Starting 1st iteration ...\n",
      "MSE train at step 0: 35.185 (0.01m)\n",
      "MSE train at step 5000: 11.748 (1.12m)\n",
      "MSE train at step 10000: 5.421 (2.22m)\n",
      "MSE train at step 15000: 4.135 (3.32m)\n",
      "MSE train at step 20000: 3.667 (4.43m)\n",
      "MSE train at step 25000: 3.407 (5.53m)\n",
      "MSE train at step 30000: 3.248 (6.64m)\n",
      "MSE train at step 35000: 3.149 (7.74m)\n",
      "MSE train at step 40000: 3.071 (8.85m)\n",
      "MSE train at step 45000: 3.014 (9.95m)\n",
      "MSE train at step 50000: 2.965 (11.05m)\n",
      "MSE train at step 55000: 2.924 (12.16m)\n",
      "MSE train at step 60000: 2.887 (13.26m)\n",
      "MSE train at step 65000: 2.854 (14.36m)\n",
      "MSE train at step 70000: 2.822 (15.47m)\n",
      "MSE train at step 75000: 2.794 (16.57m)\n",
      "MSE train at step 80000: 2.766 (17.68m)\n",
      "MSE train at step 85000: 2.738 (18.78m)\n",
      "MSE train at step 90000: 2.713 (19.89m)\n",
      "MSE train at step 95000: 2.689 (20.99m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n",
      "Fold # 5\n",
      "Getting data ...\n",
      "Initializing the model ...\n",
      "Train: 144704 out of 178766\n",
      "Starting model training ...\n",
      "TF session started ...\n",
      "Starting 1st iteration ...\n",
      "MSE train at step 0: 34.984 (0.01m)\n",
      "MSE train at step 5000: 12.036 (1.16m)\n",
      "MSE train at step 10000: 5.513 (2.32m)\n",
      "MSE train at step 15000: 4.191 (3.47m)\n",
      "MSE train at step 20000: 3.729 (4.62m)\n",
      "MSE train at step 25000: 3.505 (5.78m)\n",
      "MSE train at step 30000: 3.362 (6.93m)\n",
      "MSE train at step 35000: 3.258 (8.09m)\n",
      "MSE train at step 40000: 3.175 (9.24m)\n",
      "MSE train at step 45000: 3.107 (10.40m)\n",
      "MSE train at step 50000: 3.052 (11.55m)\n",
      "MSE train at step 55000: 3.000 (12.70m)\n",
      "MSE train at step 60000: 2.955 (13.86m)\n",
      "MSE train at step 65000: 2.911 (15.01m)\n",
      "MSE train at step 70000: 2.871 (16.16m)\n",
      "MSE train at step 75000: 2.834 (17.32m)\n",
      "MSE train at step 80000: 2.802 (18.47m)\n",
      "MSE train at step 85000: 2.772 (19.63m)\n",
      "MSE train at step 90000: 2.743 (20.78m)\n",
      "MSE train at step 95000: 2.718 (21.93m)\n",
      "Saving model parameters and predictions ...\n",
      "DONE\n"
     ]
    }
   ],
   "source": [
    "# 5-fold cross-validation: train each model spec on every fold and pickle its parameters and predictions\n",
    "output_dir = '../result/cv_pred/'\n",
    "indication_specific_degree = 1\n",
    "\n",
    "for model_spec_name in ['cadrres', 'cadrres-wo-sample-bias-weight']:\n",
    "    \n",
    "    for k in range(1, 5+1):\n",
    "\n",
    "        print (\"Fold #\", k)\n",
    "\n",
    "        X_train = pd.read_csv('../preprocessed_data/GDSC/cv_data/{}_5f_{}.csv'.format('X_train', k), index_col=0)\n",
    "        Y_train = pd.read_csv('../preprocessed_data/GDSC/cv_data/{}_5f_{}.csv'.format('Y_train', k), index_col=0)\n",
    "        X_test = pd.read_csv('../preprocessed_data/GDSC/cv_data/{}_5f_{}.csv'.format('X_test', k), index_col=0)\n",
    "        Y_test = pd.read_csv('../preprocessed_data/GDSC/cv_data/{}_5f_{}.csv'.format('Y_test', k), index_col=0)\n",
    "\n",
    "        #########################\n",
    "\n",
    "        ##### Prepare x0 for calculating logistic sample weight (o_i) #####\n",
    "\n",
    "        sample_weights_logistic_x0_df = model.get_sample_weights_logistic_x0(gdsc_drug_df, 'log2_max_conc', X_train.index)\n",
    "\n",
    "        ##### Prepare indication weight (skip for this analysis = set all to 1) #####\n",
    "\n",
    "        indication_weight_df = pd.DataFrame(np.ones(Y_train.shape), index=Y_train.index, columns=Y_train.columns)\n",
    "        k_sample_list = X_train.index\n",
    "        indication_weight_df.loc[k_sample_list, :] = indication_weight_df.loc[k_sample_list, :] * indication_specific_degree\n",
    "\n",
    "        #########################\n",
    "\n",
    "        if model_spec_name in ['cadrres', 'cadrres-wo-sample-bias']:\n",
    "            cadrres_model_dict, cadrres_output_dict = model.train_model(Y_train, X_train, Y_test, X_test, 10, 0.0, 100000, 0.01, model_spec_name=model_spec_name, save_interval=5000, output_dir=output_dir)\n",
    "        elif model_spec_name in ['cadrres-wo-sample-bias-weight']:\n",
    "            cadrres_model_dict, cadrres_output_dict = model.train_model_logistic_weight(Y_train, X_train, Y_test, X_test, sample_weights_logistic_x0_df, indication_weight_df, 10, 0.0, 100000, 0.01, model_spec_name=model_spec_name, save_interval=5000, output_dir=output_dir)\n",
    "        else:\n",
    "            # Fail loudly: otherwise the previous iteration's results would be pickled under this spec name\n",
    "            raise ValueError('Unknown model_spec_name: {}'.format(model_spec_name))\n",
    "\n",
    "        #########################\n",
    "\n",
    "        ##### Save model and data #####\n",
    "        # Context managers guarantee each pickle file is flushed and closed before the next fold starts\n",
    "        with open(output_dir + '{}_5f_{}_param_dict.pickle'.format(model_spec_name, k), 'wb') as f:\n",
    "            pickle.dump(cadrres_model_dict, f)\n",
    "        with open(output_dir + '{}_5f_{}_output_dict.pickle'.format(model_spec_name, k), 'wb') as f:\n",
    "            pickle.dump(cadrres_output_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T10:23:39.956652Z",
     "start_time": "2020-06-23T10:23:39.938970Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load one fold's saved results for a quick sanity check.\n",
    "# The model spec is named explicitly instead of relying on the loop variable left over from the training cell.\n",
    "inspect_model_spec_name = 'cadrres-wo-sample-bias-weight'\n",
    "inspect_fold = 1\n",
    "\n",
    "with open(output_dir + '{}_5f_{}_param_dict.pickle'.format(inspect_model_spec_name, inspect_fold), 'rb') as f:\n",
    "    cadrres_model_dict = pickle.load(f)\n",
    "with open(output_dir + '{}_5f_{}_output_dict.pickle'.format(inspect_model_spec_name, inspect_fold), 'rb') as f:\n",
    "    cadrres_output_dict = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T10:23:40.466596Z",
     "start_time": "2020-06-23T10:23:39.958405Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fcd4d95ebd0>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnX+QFOd557/PDI2YQXfMEu85MNEKQtlQJit2xcbBRy4VuJywgyWvwRL2ibtUJWXlquyqk6LautVZZVY2PraOslHVXS4XueKy70xsJCOvkVEFnQN1riKR7d3sIrQ2XCSjXyPFWgdGsdkBZnff+2Omh56e93377Z7ume6Z51NFwc70dL87dD/99PPj+5AQAgzDMEznk2r3AhiGYZjWwAafYRimS2CDzzAM0yWwwWcYhukS2OAzDMN0CWzwGYZhuoRQDD4RfYWI3iaiFx2vjRFRgYhmqn9+P4xjMQzDMMEIy8P/KoAPSl4/IoQYqP55NqRjMQzDMAEIxeALIb4P4HIY+2IYhmGiYVnE+/80Ef17AJMAHhZCXNFt/K53vUusW7cu4iUxDMN0FlNTUz8XQvR6bUdhSSsQ0ToA3xVC/Eb153cD+DkAAeDzANYIIf5Q8rkHADwAAH19fVtfffXVUNbDMAzTLRDRlBBiyGu7yKp0hBA/E0IsCiGWAHwZwPsV2z0hhBgSQgz19nreoBiGYZiARGbwiWiN48ePAnhRtS3DMAwTPaHE8InoGwB+F8C7iOgNAAcA/C4RDaAS0nkFwB+HcSyGYRgmGKEYfCHEJyQv/0UY+2YYhmHCgTttGYZhuoSoyzIZhmkDE9MFHD51EW8WS1iby2Bk10YMD+bbvSymzbDBZ5gOY2K6gEeePo9SeREAUCiW8MjT5wGAjX6XwwafYTqMw6cu1oy9Tam8iMOnLmoNfjufCkyPzU8uzcEGn2E6jDeLJV+vA9E/FegMtemx+cmleThpyzAdxtpcxtfrgP6poFlsQ10oliBw01BPTBd8HTvKNXYL7OEzTALRecwjuzbWecIAkLHSGNm1Ubk/r6cC+3iFYglpIiwKgbxhSEVlqB88NlPbp4xCsYT1oyexKmOBCLgyX/a1due6OQRUgQ0+w8SQZkIg9nZ+DN3aXEZqeHNZC4Ofe67O2C5W9bdMQyo6g1wolkCodGfKEACKJbmhd67difPm5Ny3LlT02DOztd8xl7Ewds/mjrwxhCaeFgZDQ0NicnKy3ctgmLbiNuhAxUM/tKcfw4N5bB8/LTXO+VwGZ0d3NuzLNBnqPqaVJkAA5SW9jXAfV/Y04IXO6HvhNNCy30O13onpAsZOzEpvKFaKcPjeLbFNcrsxFU9jD59hYoYqBPLwk+cAqD3mQrGEiemC8ZOA22Dt3ZrHmQtztZ+vXl/w9K7t/dq4j2li7IGKsc/nMnizGuf3Q7FUrv1esu9Otl6vG0N5SWirmpKaQOakLcPEDJVBXxQCjzx9HrmspfysaTJ0YrqAkafO1SVSj/3wdYzs2ohL47txdnSnkbG30R3TBNvrvjS+G7mM+vdTUSov4rFnZpX5ACcEYOzErOc6daGopCaQ2eAzTMzQVdOUyosQohLiUb1vGx1dInbsxGxDqKa8JDB2YhZAxYCTjzU/eGwG60ZPGhlcNykCrl5fwPrRk9g+fhrzNxZ87wNQJ3XdmOQFAP3/Q5DS1zjABp9hYsbIro1Kgw4A75TKOLSnX/m+Xd2SIrnJXpvLKA2e/frhUxcDx9T9siQqx7WfNG4stj+vaKVIW9UUpPQ1DrDBZzqWiekCto+frnmOdtgh7gwP5nFoTz/SGoM9PJhHXmNcBOTxcwKwY5N+0NDg554L5Kl3CrmM5Zmwld2UvUpf4wAbfKYj8Wr2iTvDg3l88b4tWqPi9SQgQwD4+vOvabcxDY10ErmMhZ6sBQKw8hbvWhb7ppzPZUCo5CDsKqo4w2WZTEfi
p3TRi1aV38mOA+jr6Z2fCftKbqZUMmlkrLSyDDYJmJZlssFnOpL1oyelxooAXBrfbbwfr5r4sAjjOKqbXDP0ZK2O8Ph1/QCq97JWCj/+/Id8H6sd9fltH2LOMO0krKRaq8rvdPIDGx55FusM8hBBQjw6erIW/qkUrGImbqiMfcZKK9+bLy/h0Ynzvo4T91AiG3ymIwkrqdaq8jvd/txSBirjIYsrZ6xgl3iqql1j2jgVV1IetaV39q3SJr+/8YPXfR0v7vX5bPCZjiSspFqryu9M9+dlPIYH87UGprOjO3Fozx2wvKyeBA81hcTg9Xv87U8va50Avzc8nRBcHGCDz3QsbuMXJI7aqvI7P+EYP8ZjeDCPw/duCbqsjmdJVL4j3S3RT1mvqpRW9XqrYYPPMBpaVX5nH8dEVoAAXzHhpFSatAuvrmI/sXjVE0FcQmMsnsYwHjglh6M+zuFTFz3b/gXgOa7QTS5j+dLG6RYyVgqHT13EksG2JmMi8wqZaV2eoJVVPezhM0yMME0G+00aj92zOVAsvxPIWinl735ozx2+vkuvbf2GAFtd1cMGn2FihGny1m/S2I7l6zzNTqVUXsLhe7fUhct6shYe3zeA4cG8Vn3Ujdf37jcE2OqqHg7pMEyMkI0ndBM0aWyHpkyGhMSZjJXCoT13GP8OtvaQyuiahtdNv3c/IcBWq26yh88wMULmIe7f1hdq0th5jCSywko3fE86vIz0O5rcRtTJ+larbrKHzzAxoxVJYucxdKP+4kixKvXg/B1UshK5jOX5Xarm+QbRXfJLkIHzzcAePsMkhKjknocH85g5cBdeGd+Nx/cNVGbZtpHtG1Zrnz5k3u/Iro0NiVkrRRi7Z7Pn8dopddxq1U328BkmAchmqI48dQ6PPTOL4nw5tHK+w6cuotyiASQpknfCvvKPpdqQcV/er/s+ZXjfsr+zdg0kb1XZL8AGn2ESgayao7wkakqWqiHafmq8J6YLLZUAUMke2AlLP4ZYdqMqL+oHkTtppdFtJ6EYfCL6CoAPA3hbCPEb1ddWAzgGYB2AVwDcJ4S4EsbxGKbbMKnacDcGyZ4K7JsCUG9Id2zqxfGpeCg6OkM2poY4qTNmW01YHv5XAfx3AP/L8doogL8WQowT0Wj15/8U0vEYpmuYmC4gpdFzd1IolrB9/DTeLJaknymVF/HQsZm6WQGFYglHn38t0LATIvOyRhOcIRs/TyeqxGvcZ8y2mlCStkKI7wO47Hr5IwC+Vv331wAMh3EshukmbC/dVIuFgFrXpuozsleD2mwhvCWIvbA/nybC3q31vQLODtQHj81g82f/qi5ZbSeyC8VSQ8g+CTNmW02UMfx3CyHeqv77HwC8O8JjMUxHIovd21hpqotbt2skYbNSyvbnF4XA0eq83TMX5qS/99Ubixj51rnaz86QlcDN7yDf4sRrUmhJ0lYIIYhIeloQ0QMAHgCAvr6+ViyHYRKDLga97zdvw5kLcygUS9oRfklCAJ7hpfKiwIPHZqS/s23so66fTypR1uH/jIjWAED177dlGwkhnhBCDAkhhnp7eyNcDsMkD10M+syFuVoNuc7Yx0SK3RgBM/141e8cdaI2qn6IVhClh38CwB8AGK/+/Z0Ij8UwicIkITkxXcD8DfVM2UKxhLETs956Mgl0/Jt5WokyUaurfEpC+CisssxvAPhdAO8iojcAHEDF0D9JRH8E4FUA94VxLIaJE16G2/3+jk29OPnCW7X6eUBuNEwFzkzkEBJo7wPr90edqNWpW3aNwRdCfELx1r8OY/8ME0e8vD3Z+1+vJiXdlMqLePDYDA6fuoiRXRu1ydpuoLxoMpKkgrNj95Zl0arFJL3en7V0GCYgXlrmQYy2fdOIy9DrdnH1htn3ZqWpLt5fLJUjHSDSanXLsGGDzzAB8fL2gnp9pfJibIZex42Vy9N1QmMrly9D2VUXGuUAkXYKrYUBa+kwTEC8ujtV75vQCSWWYWOlCF/4aL2S
5PrRk9JtowqxtFtorVnY4DNMQLy0zE2mV6nI5zJY9ysZnH3Z3cDe+agayG5dsazBsLZDUiHJQmsc0mGYgHhpmQ8P5rF3q94w7N/WpwwRHP3kB7B/W19XhXfyuYyyqqg431i1k/QQS6thD59hmsDL2ztzYU75Xj6XwcHhfgzdvloZIjg43I+Dw/3KiU6dgrM7VvW7yrz2sEIsfoTakgwbfIaJEF0s2fZCTUIEuv2sXJ42rmoxQTWYJEquXl/A+tGTdVLNpoNPmg2xJL2Zyg8c0mGYCFHFkk1mrZrsJ5/LIJddHmhtSnwa+6yVwvYNq00HTEkplso1Vcyjz7+GO/tWtWzsn1d5bSfBBp9hIkQVYzaZtWqyn5FdG0OvSDFveaqs4b/suQOv/GMptI5eAeBvXr6MkV0bcWl8N86O7ozU0056M5Uf2OAzTIQEHVLtFugCoNxPO5t+bE9YZRwJuDkc3YdwvgBa5mEnvZnKDxzDZ5iI8RtjVsWUD+3pl8r+NlP+GQZ2olOXaFUlV8dOzCo1c1rlYXuV13YSbPAZJmb4FehyGtOwKnmyVgrlJdEwGFyGbby9jKbsxvfYM7Pa/baCpDdT+YENPsPEjCAxZduYrlN0ntqYDkq5xUpjzx1rcObCXMOgc5lRD2o0ZbX1Nq30sJPcTOUHNvgMEzOa6R7Na+QcMlYae7fmjQaWX5kv4/hUoSHfoOsZ8GM07bp31Tp6sv6qmBgz2OAzTMzYsam3wSibxpRV8fyerIUDd1cqg45PvYFSub4WRyZnIAsjheEJe2n9Z6x0ba1MuLDBZ5gYMTFdwPGpQp3xJQB7t5oZWnsbZzLUaexHvnWuIS7fk7XqBrI4CZI41XWtTkwX8PCT55RhJR4+Hi1s8BkmRsgStgJ6iQYZ1xduevBX5isa8QR1ElYVCsplLWwfP403iyWsylggqsTddWMZVV2rAPDI0+eVxp6A2vCXh47NdHTytF2wwWe6nih0VILu0ythOzFdwGPPzNY88lzGwtg9m+v2raryUXFlvowDd29uCLNYacIvry3UjuUsn1TJD3h1rerWkctaXSNx0C7Y4DNdjcmYQrvc0a5w8Qo7eO3z0Ynz+MYPXseiECAAy5elah45KbSB1+YymJguNIRkiqUyRp46V/s5aGnmg8dmkMtYWGGlah781esL2rmypfIiHn7yXO33AoJ3rWasNIRovCEkaV5sEiARo0ELQ0NDYnJyst3LYLoIlTJjXlFbbpOx0nUVLE6PPqUofcxXSxtVc21V2MfSGfNcxsL1haWmm6/SKcLSkvAtk2A/aajWmK9WGOnWr7q5EIBL47t9rqi7IKIpIcSQ53Zs8Jluw2mcVWc/wXtile3x+1GXVA330JGxUlhhpZWJ1bhgl326a/Ur76UaKoNMcUonM3JMDT6HdJjYEaU2uVdJoM3aXMYzDGF78X6khIO4V6XyUmBj2UpK5UWcuTAnrfUPuv5OlThoF2zwmVjRrDa5181CllSUsWNTL06+8Fbsveq48WaxhDMX5kJRziQgUlnkboQNPhMr/OrIOJmYLmDkqXMoV13uQrFUS2h6JRXd+I2zMxVsTfuw9uXG5OmvW6ZXBYENPhMrmtEmHzsxWzP2NuUlgbETs7ULPqdpMmLih/NGb/L0p9pm8tXLdU9ssnLWboANPhMrmtGRUVV52K9PTBfY2CeMQrGkFYRzP/2pnhDdT2zOctZuMvo8AIWJFbrJTjompgue7z/sqFdnOgf7pjD4ued8hZPKSwIPHpvB9vHTnudPp8BlmUzsCBKDVdXT25jKAjPdCQG4f1sfDg73a7eLa37AtCyTPXwmdgwP5nF2dCeO7BsAADxk4IWZllAyjAwB4Ojzr2nPMTs/UKj2b9j5gSQ9HbDBZ9qKe3arffH4vbg6cf4o01q85uh66QQlAU7aMi3F+Uicy1r45bWFujJKu+pCdXH9yZMzUiXFIJIFDONG9qTo1FMy/UxcidzgE9ErAH4BYBHAgkmc
iQmfOMQe3SVzsooZ22NSXUR21aW7JM+vfDDDyHA/Kd7/5b/F2Zcvaz+TIsLEdCEWsXwvIk/aVg3+kBDi517bctI2GmRyAm7xr1bglVi1MdGxsfGjY8MwOgio5Y38qo6243pywklbpkZcYo+mj772E4gJbOyZMHlq8rVa7sgPSYnltyKGLwA8R0QCwJ8LIZ5owTEZB0G7V8MOA5l67YViCQ8emwl8HIYJggA8wzc6wpKUiJJWePi/LYS4E8CHAHyKiH7H+SYRPUBEk0Q0OTfHcdgoUFWw6CpboihBG9m1EVaKAn+eYeJMmvyd26oKtSiJ3MMXQhSqf79NRN8G8H4A33e8/wSAJ4BKDD/q9XQjqgqWHZt6lZ9pRsTMifspYfmyFMo3mhvSwTBxZFGI2vxfryfiZlVhgxKpwSeilQBSQohfVP99F4DPRXlMphFVBYuussVvGEgW/gHQcFIzTCdjn+NeBjwsh8ovUXv47wbwbao86iwD8JdCiL+K+JiMC5WRLhRL2D5+WuqJ6OLttphVT9bCgbs3A2g07I88fR4rrFTTI/cYJqnoDHgzqrDNEKnBF0L8FMCWKI/BeKMz3ipPRDXP1RlzuzJfViZXS+VFNvZM16My4M2owjYDl2V2ATIFSieykrLhwTwO7elHPpcBwX9CimEYtQEPqgrbLCyt0AU4tcL9tIcPD+Zrn9VpkjMM04jOgDuvyVZ2v7PB7xJs463qdtU9Sj46cT7KpTFMLCEEGzoPAHkDA+50qFoFh3S6DL+PkhPTBRxlUTKmCxEA0j77RjJWGvu39QEwk/VuNezhdxl+HyUPn7oY2MthmKSzqNHuIACrMhaIgOJ8GWtzGezY1IvjU4WW19ebwga/C/HzKJkk6VeGaSVH9g00XEfbx0+3pb7eFDb4XYiqScp+zem1pHg0IMNIkRnxdtXXm8IGP6ZEpV8va+keeeocQEB5sWLYi6WbOvVs7JluxUoRFpaEMqQpM+Ltqq83hQ1+DFHpbEy+ehlnLsz5ugm4bxxXry80PHKWDTSG7SHgzVQuMEySuHXFMuy+Y41ykprMiMsaFltRX28KG/wYoRulViov4ujzr9WMrUkyaGK6gJFvnat57s1o2SwKUTP6DNMNFOfLODjcj0tzv2yQTVYZ8XbV15sS+cQrP3TzxCvZVCoT8rkMzo7ulL43+LnnpGMEg8CePdNt9GQtZJcva8hrxc2IA+YTr9jDjwky9TwTdEOX2dgzTDCsNOGX1xZq11CxVEbGSksrc5IEN17FBK8svqr9wx1HfHTiPB46NhOaFHFP1mJjzyQePw1UaSKsXL6sIbdVKi9i7MRs2EtrKezhxwSdomVe0tABVOKIOzb11oYu5LJWaF69zbXyErJWCvPlpVD3yzCtZMlhvG1Z74eOzUidmSUh6irVnBRLZUxMF2peflTVdFHBMfw2YDIsBKgY9EN7+pUnl+wmEAVEQIxOE4ZpCvu6UhVI5DIW3imVlU+2dmy/UCw1hDsJwP3b+jB0++qW3ghMY/hs8FuMLDlrn4CAv+y+SgiNYRg9triZ7FpcYaWaflK20lSrjrP363TewoaTtjFFN9rs7OhOXydEq7r3uByT6TQKxZKyhPIhxVAfPziNPRAfeQVO2raYMFuvTbr3CMD2DavRk7Ua3rPShP3b+pDX7CdjpfGJ37pNmTRmmCRCQC0Wf3Z0Jy6N78bIro2RigXGQV6BDX6LURlpv63XE9MFXLl6veF1K03IZSwQKo+tR/YN4OgnP4Dpz96Fx/cN1CZY5XMZHP7YFhwc7seOTb3SY2SsFA7t6cfB4X6u1GE6CgHUTXmzQ626EGlP1kIu0+g4mRIHeQUO6bQY1azY+RsLddl/He4OWif7fvM2HBzul37GKY42f2MBDx2bqb0mo1Reql0UeU0VEcMkEed5r+uDcQ4zMWmQtFJUp00FxEdegQ1+i7EN+tiJ2brSryvzZalUgqyi5/Cpi1JjDwBnLsw1vOY+SZ3H9TLi
toTDnX2r2OAzbSUF4AMbVjfIHDSD7WTpwi3O4glZ3H/Hpt4GjSv3NnEp1+QqnTahqrBxSiWoKnp03gUBuDS+2+hYDJMk9m/rC70M2atE07lNHAy2CtMqHY7htwmT5K2qoidN6hSqLE4Yh2QRwzTLmQtzofec2NUzstGf7m06ATb4bUKVwBFAbQ6mylCrSiStFEnjhHFIFjFMM5jkkIJWkr1ZLdG0e2FU23QCbPBDZGK6gO3jp7F+9KTn8GKdR2HHzXOSUkodt66Qp2R0x2KYJDCya6PyyTZNhHwuE7iSzHaIhgfzyhLlTnGa2OCHhLOsS+Cm0VYZ/eHBPPZuzStP4lJ5EdfKi768Fjvx6z6ml/fCMHGGUDmHVU+2i0IYeeAy/TR39YzMOYpLhU0YsMEPCV0HrYyJ6QKOTxW0Hayl8pJvr6WT4o0MAwD/csNqAFB63/lcxtMDz1hpfOm+gYZeFHcy1naOdNskGS7LDAm/HbRB9e+DrMV++mCYOJKqKpCp9Fif/+mVSu+Jx/hA93u2sFmuOrzkoWMzRiWSw4P5jjHwbtjDDwm/HbRRJoHcx4zy5sIwzbB/Wx++dN8AVmnyVYtC1BwWlfct88yP7Kt49NcXlnBlvmwUau10uA4/JHQqmDJvQVcb30xXq5WuDG94p1SuE4OKz/8ywwRDN87TibNZkQhYkpz8pvtKClyH3wZuWXbz6+zJWtrYnyo59Pi+AZwd3RlIsyNrpQBR6aS1vRk29kynYPJU7C6ekBl70311IpEbfCL6IBFdJKKXiGg06uO1A/skc0oWXPOYEOWVHBq7Z3NFk8MH1xdEw1g2NvZMp+AMVapKoE3Dl51SZumXSJO2RJQG8KcA/g2ANwD8iIhOCCF+HOVxW42uQidocsh+/bFnZo2HMbBmPdOpOJOz7vCpHZcHzD33Timz9EvUHv77AbwkhPipEOIGgG8C+EjEx2w5YWrcOxkezNfJGnuhk1xgmCSxPE3Kp1+dg2XquXdqFY4XURv8PIDXHT+/UX0tkageI8PSuFdhD2nQxfXtQSXcUct0AgIVL/zS+O6GSXA6B4u7yvW0PWlLRA8Q0SQRTc7NNUr7xgVdJ22ruvNUDnyKUBtUcmhPv3S6FcMkifKiUDYQ6hwsOzeme9bt5usj6sarAoDbHD//WvW1GkKIJwA8AVTKMiNeT2C8ZtHa20Spf11UxPKFQJ1e99iJ2VCPyzDtoFAsYfv46drQHqLKNZDLWrBSVFeg4HSwhgfz2tzXgbs3h7ZG2byKOIeLojb4PwLwHiJaj4qh/ziAfxvxMSPBK07fiu68tYr6/FzWql0Ya3OZumohhokzaSJlsQHh5oAe97Age5Sns9/Eef2pnCMgvPi9LnkcV6MfqcEXQiwQ0acBnAKQBvAVIUQi3U+VsY26vMs9mtBKU920KytN+OW1hZo3w4NOmLiScjVBWWnC4Y9tAaCWRVBRXhRYecsyzBy4S/q+6no1KX4wJWh1XjuJPIYvhHhWCPFeIcQGIcQXoj5eVEQVp9dJKrvzBsVSGRCVGKRdvbBy+bKG2nuGiRvL09TQBLVQdVxkPSkmZ7TXWMKo82pRVedFCYunGSKbZdlsvM7rkVDmQZSXBLLLl2H6sxXPZv3oycDHZ5hWcUMyg1kA+My3z9fCoc5ryWQs59pcRhlDj+J6lR2/HU/9zcBaOm3Ea67t+tGTSk8n7xhoLtsHofLILLvQGCZOvOKawQzItancvOdfrMQbV64Z61fJjtHMDcGvflaUsJZOAvB6JNR5CoViCSNPncO6X1GPSmRjzyQB2XQ4d5jHkliqv3/7qq8ZFE78DiySkUTtfPbw24jKw89lLKy8ZRkKxZJn8srrfYZJAl6e8YZHnjWWDiEAlyRPDU68nq6ThqmHzzH8NiIb6GClCFdvLNTK0LxOcTb2TCfgrm5xh1v86ESZxNCTmHANAw7ptBHZI+GtK5bVlV0yTLdgG1tZuMUU00qcqOVQ4gob/DZj6+TY
miG6hhGGSTI9WQuvjO9W1sLbxtbPhLbtG1YHiqF3+rByFRzSiQn2Iyz79kynYkdlvGbT6sIqdmdumgif+K3bcHC4P9BaWlG2GUcSb/CTpmUhw6QEjWHiQIqANauCjeB8xyGPsMJK1c73XMbC2D2ba9etrks2zIRqJw8rV5Fogx8HLQvnDSeXtSAElPoeqs+lNHoiDBMnlkRw+Q67Ucrt3FxfqJ8OJ3sCIAA7NvUGOi5zk0QbfJWWxcNPnqvbRub9h/Fk4D55nep8zpuPex07NvXi+FSh9jk29kynYxtsr2vW9ronX72Mo8+/VgtxCgDHpwoYun1113nlYZLoOvx1GlkBK0UAoa7ixa71BRrFmoJ0yJm0f+cyFq4vLPkShmKYTiRjpbVhS+c12Gl18lHTFZ22upF+5SXRUN5o1/rqVO78YFKzWyyVG47Fxp7pdGRXZqm8qL1mnddgt9bJR02iDX6QUMibxVJoJ1Mra3Z7slZXT+phkoXqylwUQjuC0EtWpNPr5KMm0QY/iLb12lwmtJOpmfmZfseNX5kv41p5yXtDhokBKk/erpVXvW9fg91aJx81iTb4OoNrpQhWuv6ksk+YsE4md6esKRkrjfu39fn+HJdtMklB5snb19jwYB5fvG9LJc/mwEpR3ZjCpAmTJYFEV+k4mycKxVKtKcOWDrbfU1XihFG/76zl9RJDU1ULPXhsxvdxGSbOuOW700SNeTK3t+P6uRvr5KMm0VU6cSOIPrZJpQ/DJAlnNZxsmHjGSmOFlZIOGecqnGCwWmYb0LVrq+r+ueqA6QQyVgrXyku1cxtoLH22KZUXleFJvh6ihT38FqCSTuB6fKZTyGWsuoHiQZ9c2cMPBnv4EaDrztW9p1L/Uxl7K0U8mJxJFMVSGRPThdo57+WpyxoSo6rC6QS9rbBgg2+ITrcHgFbTx89japoIt65YJo1vMkyccQ4wUQmgARXDPnbP5tpnojTEcdDbihNs8A3x6s5VvTc8mNee/G6WhGBNfKZl2NU0bu2aIDgdG5kAGtCojBm10dVdt2zwGSVBunPt91Qnvwy78YQrdzqDlcvTuHoj+v6JFAG3LEuh5LM5r1AsYezELIgqIUa7tDmXsXD1xoKv6WvOxsW46M17XbfLnUnSAAATMElEQVTdFu5hg2/AxHRBKWGsM9DOC6ByMeov/BQqioJPT73R3IKZ2NAKYw9Uhov85PMf0goKqig6dOrthilZyGXHpl58/fnXlPtxx9/jUEeverqWSTV3Q7iHDb4H9kkhM/Z2kmny1cvSC2HHpl5phY6qOodShGM/fJ0TtoxvbOeiJ2s1nf+xQx5nR3c2GL4zF+aUzYVxNJK66VrdGO5hg++BqsImTVRrqFKpbJ65MIczF+akaplpyRPD4pIAiycwfnFWtxy4ezNGvnXOVyhGRqFYwsBjz9W8/56sVdm3YjjJh7esiWV4RBdaekjR4d7JvQBs8D1Q/ecvCeFZgqaLw/PQE8ZKAUH18HqyForzjZPV7L8ffvJc0+eYM9RzZb6MkW+dw+GPbcHerfmG4STHfvg6jv3o9dqNJk7hEVVoSRfu6VQSLZ7WCkyUNTv5BGGioxnx0+zyZbg0vlsadhkezGOpCWOvEvQrLwo8/OQ5fF1SzaObPxFXulGRkw2+ByYnRVCZZL8SyQxj82axhEcnzmPDI89i3ehJbHjkWTw6cbMvROWEqM45qv7J5zLa0ky/Tw1xDo90oyInh3Q8MCkvc6t2miJQOcl4kDnjl+zydF2hwKIQtZ8PDvcrK2pUZ5gA8Mr4bgDhCvrF/ek3DpVErSQyLR0iGgPwSQBz1Zf+sxDiWd1nOkFLZ/3oSePmlTQRloTAqoyFGwuLmOcBJx0FATiyb6DpJKp7FmzGSuNaeVF5nuV9NPo5sQ1+EMlu2QxpK01YuXwZ3ik15hqYcInLTNsjQoiB6h+tse8U/Hg0i0JAoJIcY2MfH0xCbfbEJt22a3MZHD51sSlj35O1pGEH
3R6DGPtcxsLEdAHbx08rq1dU5HMZHL53Cw5/bEttnT1ZCxCVc1vgZhJ3Yrrge21MeHBIxwCTcjN7m0KxxCqYCSZjpbF3ax7fPfdWXZWKexs71jsxXVBqvo/s2ujbeLoRojGsePjUxUpnbIgn2Ye3rDHuBreRzXpwDgNyfydeNe5xLOvsNKL28D9NRC8Q0VeIqEe2ARE9QESTRDQ5Nzcn26St2I1ThWJJ6ak4twHY2CeNnqxV5z0fHO7HzIG78Pi+gdrcZNujdyf2hgfzmP7szW3dyT/VE1+aCPu39XkOpn+nqkLpPgfDTvjL+kW82LtVHf/2K0Vicp0xzdOUh09E3wPwq5K3PgPgzwB8HhX793kAXwTwh+4NhRBPAHgCqMTwm1lPFJh046mas3IZC0SoeTq5jIUPb1mDb/zgdU7Qxojs8mWY/uxdyvcJwK+uWqH1OFXJP1Wnp31DOHNhTtsZmyKSxtPDbMa2Cwf8cnyqgKHbV4dS496NXa/toCmDL4T4PZPtiOjLAL7bzLHahYmnotrmnVIZl6qJMEAv08C0D9n/X1g6K15VXl6GNupzxSkzoJJMcOvW2+gMsk7SQEYQcULGP5HF8IlojRDireqPHwXwYlTHChtnLNFLNM3+t4k3o3oSYNqLzOtUeZyPPTPr2+OUef/2OdbKW3++KoB25sKc9OYjM9Af3rIGJ194y/dIQr9qmd3Y9doOokza/lciGkC1xBfAH0d4rNBwe3Y60TQbU2+GvZX4YaVI6nWq/q+uzJfx6MR5HBzu930s08S+uwwzDFYuT2tHB8oM9I5NvTg+VdCuRWeQ/dS4+30iYIIRmcEXQvy7qPYdJTqxtCUhPBuvdN6Mn0EojH+CVK7cumKZdMj8CkutLX/0+deUsWsVj06cb9CfkWEPJPHbxOeFiUyz20BvHz+tNfZhGuS46Od3OlyW6UInluaMx7sx8Wb8DEJh/LFyeRrzAbTni/NlabxehwB8JRMnpgtG06QIqPPCZR5vM+eOc+as7D23sdU9keYjMMjd1vXaDlhLx4WJWFpQnNodTLh89M48UuS/WNFujvJrSP2E50xj9e6JUbKGKx32tipUZY6qkshVGXnJaD6XkYq2MfGHDb6LqBX0hgfzODu6E4/vGwgkuMY0snJ5GsenCr4rWggV4xYkdOLHATC9OezY1Fv3s32uOFUxcwojDFTO3fu39SnfV6lXqhLUROg6NclOhw2+C78KenY7+vrRk9g+flrZKOLeDqg0rjDNQQCstPf4SKCSoLUbnZrphvZr9ExvDmcueDcejt2zuaJb42L/tj4MD+ZxcLgf+zVGX3bzUd2QivPlrlOT7HQ4hi/BNJZoWqv96MT5OuXCQrHkW5yq28laKQhQw6Sl+7f14ahmzqpNxkrh0J47MDyYD6QGad8ggsSuTXM3Jk8CuuSmMw4vm6gGyG8+qmICO0Tmru5hCYTkwga/CUy6A93GnglGqbyEI/sGpIZGNWfVyeqVtxg3O9noKrP8YCqfbfokoKrr91tObKO6IS0K0eDAdOPg706CQzpN4NUdODFdYGNviJUirFyuzmmszWVqMe0j+wYAAA8dm8H28dPYsanXMx/yZrFUC6uZhHIyVhpfvG+LcqqUX+y165KqQWPjE9MFPPzkOWU5sTscIwsvHtrTX9MLcuKO+6ucnLETs4HWzrQWNvhN4FXRE+fxbnGiJ2vh8L1bMPu5D2L/tr4Go+j0TGUVJcenCti7Na+tfsplrTqBOzdWmiraR4g2Vq2qfMllrEDH85LrsMuJ7ZuWqiLH3laGiYxIsSryxsQbDuk0gexRmFCptpiYLnCTlQd2DH7o9tU4fOoiHjo2g7W5DO7f1ldr/89lLQhR8eYPn7qI+RsLUg/zzIU5nB3d2RByACo3jOvlRWUjVRQ15W5sGWWZ5LKVIozdsznQfr1KSk3kPWwv3kTeQNc8yEJn8Yc9/CYYHsxj79Z8nUcqABz70esYeepcu5aVGASA7557
S+qxj+zaiCP7BnCtvFQ3REOlLGl7nrIqq71b88oBM3azU9TG/pGnzyvXbnf7BsErH3H1+kKd560LQ5rObw66liCYVsExZrCHb4iqMuHMhbmGmHAzE466DZnH64wbmzZEuZuW3BIBJp+LCi8vvKiRR/bCS66jWCrXJVV1Xrzp/GbZwBd7H2HCCeLwYQ/fAN1wBhZEi4Y3iyXj79arLl63n1Y0EXn9Hs0YSplX7sZ5A/Xy4mXNXm4O3L25JQ1ZuvATEww2+AZ4xT1NsSsmwp5W1ImszWWU320uY/lqBtLtpxWeou4cadZQukNYKnQhL78J6jD2YQJr5IcPh3QM0J14R/YNGDXVOKccrRs9GcUyOwanEZQlYMfu2dx045O9n1agqnPPZSzfv4sMZwhL1VSmC3k1e8yoYI388GGDb4Bp3FMXS72+sIjJVy9z7NEDWcVMs12dptK7UXWQtlL6t5N05Tvpd4kLJGI0bm9oaEhMTk62exkNqEr93I+xsu3cbN+wGrNv/kKarOxmZN9nKzH9P04CnSR90Em/S5QQ0ZQQYshzOzb4ZpieeM6pRir2b+vDsR++jnKYk6hjjh2+sL/DVdUB78X5ciwuZFUoxJYCZpg4Y2rwOaRjiC5mKbsZ6MTRzlyYw+F7tyjL2zqRd0rlWA+44AQh0w1wlU6TqEo2vSomhgfzmP7sXXh834BUwyTO9GQtvDK+G4/vG6jJDXsR90RblINvGCYusMFvElXJZtZDCMxmeDCPT/zWbZGtLwrsRiH7piXTv3GiGhQeJ6IefMMwcYANfpOoHvnnbyxi+4bVDa/bWjs2E9MFHJ9KVru4AOra3GXdxk6akQ5oFa2qLWeYdsIx/CbRlWwe/eQH8OjE+boB1gLA8akChm5fjeHBfKB5qnGgUCzV9IK84tzNSAe0kjjnGBgmDNjDbxKvUIDM+y2VF/HYM7OJV9QsLwmMnZj1jHNzHJxh4gEb/CbxCgWovN8r8+WOUNQslsoY2bVRGcMntEavhmEYbzikEwLOUIBdomlru2eXp3H1hjxk0yl1+MODeUy+erkudAXc1LvnMAnDxAM2+CEik3PtFg4O99cGmXBXJMPEEzb4IRJ1AjZjpbF3az5WXbrOOnxOejJMvOEYfohE2ZWZJsKhPf04ONyPW1eEd58mVKQeHt834Kmr7sZKEw7c3RrFSRU8EYlhzGEPP0S8pg8FxS3iFVaZo0yeVycJ4SRNhH2/eVtbPXqeiMQw/mAPP0RMpg/5pSdrNTQAhVHmmMtYmDlwV8P4urzhvheFwPGpQls9ap6IxDD+aMrgE9G9RDRLREtENOR67xEieomILhLRruaWmQxkJZr7t/XVjKitmWO/bnJzuCYZvu3s1A2CbviHn5tWu40rC54xjD+aDem8CGAPgD93vkhE7wPwcQCbAawF8D0ieq8QInktpT7xk7i0q1oKxRIIkMoT2EbVuc8zF+YCr082YMSJ6UAXm6grkXSy1DwRiWH80ZSHL4T4iRBC5uJ9BMA3hRDXhRCXALwE4P3NHKsTsQdGvzK+G0f2DSi3c3usQT1YApSDqWXrMknkEhBZWEc3PB5gwTOG8UtUMfw8gNcdP79RfY1RoIufuz3WoB6s3885Q1QqBBBZWMcrRs+CZwzjD8+QDhF9D8CvSt76jBDiO80ugIgeAPAAAPT19TW7u1jhdzzbjk29Dd2qMo9VNRRbR1CJYmeISjV8PaqYuUmMnmv/GcYcT4MvhPi9APstAHCKvP9a9TXZ/p8A8ARQGXEY4FixxG/JoC2T7P4C7uxb1bC9/bNsYpYqFxCGRHG+xTFzjtEzTLhEFdI5AeDjRHQLEa0H8B4AP4zoWLHEb8mgqkv3b16+LI2ROydmOUMaqjtmGLX7rY6Zc4yeYcKlqSodIvoogP8GoBfASSKaEULsEkLMEtGTAH4MYAHAp7qhQseJ35JB1et2jFxXVeN8TzWMOwyv2FnB0wq9nFYfj2E6naYMvhDi2wC+rXjv
CwC+0Mz+k4zfcISuS9dPjFwW3w/TK251zJxj9AwTHtxpGxF+wxE6TXk/3nlYlSusUcMwnQdr6USE33CESlM+iHferFfMGjUM05mQEPEpjBkaGhKTk5PtXkZb8VvKGQWqPEA+l8HZ0Z0tXQvDMN4Q0ZQQYshrO/bwY0YcYtasUcMwnQnH8JkGdIllhmGSCxt8pgGuf2eYzoRDOkwDXP/OMJ0JG3xGShxyCQzDhAuHdBiGYboENvgMwzBdAht8hmGYLoENPsMwTJfABp9hGKZLiJW0AhHNAXg1ot2/C8DPI9p32CRlrUlZJ8BrjYKkrBPo/LXeLoTo9dooVgY/Soho0kRrIg4kZa1JWSfAa42CpKwT4LXacEiHYRimS2CDzzAM0yV0k8F/ot0L8EFS1pqUdQK81ihIyjoBXiuALorhMwzDdDvd5OEzDMN0NR1v8InoXiKaJaIlIhpyvL6OiEpENFP98z/juM7qe48Q0UtEdJGIdrVrjTKIaIyICo7v8ffbvSY3RPTB6nf3EhGNtns9KojoFSI6X/0eYzX6jYi+QkRvE9GLjtdWE9H/IaK/r/7d08412ijWGrvzlIhuI6IzRPTj6rX/H6uvR/a9drzBB/AigD0Avi9572UhxED1z39o8brcSNdJRO8D8HEAmwF8EMD/IKJ048fbyhHH9/hsuxfjpPpd/SmADwF4H4BPVL/TuLKj+j3GrYTwq6icf05GAfy1EOI9AP66+nMc+Coa1wrE7zxdAPCwEOJ9ALYB+FT13Izse+14gy+E+IkQ4mK71+GFZp0fAfBNIcR1IcQlAC8BeH9rV5do3g/gJSHET4UQNwB8E5XvlPGBEOL7AC67Xv4IgK9V//01AMMtXZQCxVpjhxDiLSHE31X//QsAPwGQR4Tfa8cbfA/WE9E0Ef1fIvpX7V6MgjyA1x0/v1F9LU58moheqD5Kx+Kx3kESvj8bAeA5IpoiogfavRgD3i2EeKv6738A8O52LsaA2J6nRLQOwCCAHyDC77UjDD4RfY+IXpT80XlybwHoE0IMAvgTAH9JRP88hutsOx7r/jMAGwAMoPKdfrGti002vy2EuBOV8NOniOh32r0gU0Sl3C/OJX+xPU+J6FYAxwE8KIT4J+d7YX+vHTHxSgjxewE+cx3A9eq/p4joZQDvBRBZsizIOgEUANzm+PnXqq+1DNN1E9GXAXw34uX4pe3fnylCiEL177eJ6NuohKNkuae48DMiWiOEeIuI1gB4u90LUiGE+Jn97zidp0RkoWLsjwohnq6+HNn32hEefhCIqNdOfhLRrwN4D4CftndVUk4A+DgR3UJE61FZ5w/bvKYa1RPS5qOoJJ/jxI8AvIeI1hPRclQS4CfavKYGiGglEf0z+98A7kL8vks3JwD8QfXffwDgO21ci5Y4nqdERAD+AsBPhBBfcrwV3fcqhOjoP6j8576Bijf/MwCnqq/vBTALYAbA3wG4O47rrL73GQAvA7gI4EPt/k5d6/7fAM4DeKF6oq5p95oka/x9AP+v+h1+pt3rUazx1wGcq/6Zjds6AXwDlVBIuXqe/hGAX0GliuTvAXwPwOp2r1Oz1tidpwB+G5VwzQtVOzRTPVcj+16505ZhGKZL6NqQDsMwTLfBBp9hGKZLYIPPMAzTJbDBZxiG6RLY4DMMw3QJbPAZhmG6BDb4DMMwXQIbfIZhmC7h/wPfXSYTtCCf2QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Observed vs. predicted values on the held-out fold; entries missing (NaN) in the observations are skipped\n",
    "y = cadrres_output_dict['pred_test_df'].values.flatten()\n",
    "x = cadrres_output_dict['obs_test_df'].values.flatten()\n",
    "observed = ~np.isnan(x)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x[observed], y[observed])\n",
    "ax.set(xlabel='Observed (obs_test_df)', ylabel='Predicted (pred_test_df)', title='Held-out fold: observed vs. predicted');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-23T10:37:29.814622Z",
     "start_time": "2020-06-23T10:37:29.693573Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cadrres_5f_1_output_dict.pickle\r\n",
      "cadrres_5f_1_param_dict.pickle\r\n",
      "cadrres_5f_2_output_dict.pickle\r\n",
      "cadrres_5f_2_param_dict.pickle\r\n",
      "cadrres_5f_3_output_dict.pickle\r\n",
      "cadrres_5f_3_param_dict.pickle\r\n",
      "cadrres_5f_4_output_dict.pickle\r\n",
      "cadrres_5f_4_param_dict.pickle\r\n",
      "cadrres_5f_5_output_dict.pickle\r\n",
      "cadrres_5f_5_param_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_1_output_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_1_param_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_2_output_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_2_param_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_3_output_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_3_param_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_4_output_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_4_param_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_5_output_dict.pickle\r\n",
      "cadrres-wo-sample-bias-weight_5f_5_param_dict.pickle\r\n"
     ]
    }
   ],
   "source": [
    "!ls {output_dir}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
